package mybatis.generator.support.parse;

import mybatis.generator.ConflictAction;
import mybatis.generator.FilterType;
import mybatis.generator.annotation.*;
import mybatis.generator.ext.AdditionalFilter;
import mybatis.generator.ext.InterceptorContext;
import mybatis.generator.support.Reflection;
import mybatis.generator.support.StringUtils;
import org.apache.ibatis.session.RowBounds;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.annotation.AnnotationAttributes;

import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.*;

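/**
 * Describes one filter condition bound to a {@link TableColumn} (type, placeholder, value,
 * order and conflict action) and provides static helpers that collect such filters from an
 * interceptor context, from {@code @Filter} annotations and from mapper method arguments.
 *
 * Created by huangdachao on 2018/6/15 16:06.
 */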
public class FilterMeta {
    /** Column the filter condition applies to. */
    private TableColumn tableColumn;
    /** Index of the method parameter this filter was derived from, as passed in by the parsing helpers. */
    private int argIndex;
    /** Name of the bean property or map key this filter was built from; null when not set. */
    private String field;
    /** Comparison type; defaults to {@link FilterType#eq}. */
    private FilterType type = FilterType.eq;
    /** Placeholder expression for the filter value; null when none was configured. */
    private String placeholder;
    /** Value of the filter, when it is known at parse time. */
    private Object value;
    /** Ordering weight; defaults to {@link Filter#DEFAULT_ORDER}. */
    private int order = Filter.DEFAULT_ORDER;
    /** Action taken when several filters target the same column; defaults to coexist. */
    private ConflictAction conflict = ConflictAction.coexist;
    private Set<Table> referTables = new HashSet<>(); // other tables referenced while parsing the placeholder

    /** Creates an empty filter carrying the default type, order and conflict action. */
    public FilterMeta() { }

    /**
     * Converts every {@link AdditionalFilter} held by the interceptor context into a
     * {@link FilterMeta} and merges it into {@code filtered}.
     *
     * <p>Each filter value is stored in {@code additionalParams} under a generated key
     * ({@code __f0}, {@code __f1}, ...). When the filter declares a placeholder, every
     * {@code $@} in it is replaced with that key; otherwise the key itself is the placeholder.
     *
     * @param context          interceptor context supplying the additional filters and the entity class
     * @param additionalParams receives one entry per filter, keyed by the generated name
     * @param filtered         per-column filter lists that the new filters are merged into
     * @throws RuntimeException if a filter names an entity field that has no mapped column, or
     *                          names neither a column nor an entity field
     */
    public static void margeFromInterceptorContext(InterceptorContext context, Map<String, Object> additionalParams, Map<TableColumn, List<FilterMeta>> filtered) {
        // Iterate over the collection that is actually indexed below. The previous bound was taken
        // from getAdditionalColumnValues(), which could overrun or skip entries of the filter list.
        for (int i = 0; i < context.getAdditionalFilters().size(); i++) {
            AdditionalFilter filter = context.getAdditionalFilters().get(i);
            String key = "__f" + i;
            additionalParams.put(key, filter.getVal());

            TableColumn tc;
            EntityMeta em = EntityMeta.get(context.getEntityClass());
            if (!StringUtils.isEmpty(filter.getColumn())) {
                // An explicit column wins over an entity field reference.
                tc = new TableColumn(filter.getTable(), filter.getTblAlias(), filter.getColumn());
            } else if (!StringUtils.isEmpty(filter.getEntityField())) {
                tc = em.getFieldsMap().get(filter.getEntityField());
                if (tc == null) {
                    throw new RuntimeException("没有找到映射字段：" + filter);
                }
            } else {
                throw new RuntimeException("无效的AdditionalFilter：" + filter);
            }

            FilterMeta fm = new FilterMeta();
            fm.tableColumn = tc;
            fm.type = filter.getType();
            fm.value = filter.getVal();
            fm.order = filter.getOrder();
            fm.conflict = filter.getConflict();
            fm.referTables.add(tc.toTable());
            if (!StringUtils.isEmpty(filter.getPlaceholder())) {
                // Literal replacement: "$@" is a marker, not a regular expression.
                fm.placeholder = filter.getPlaceholder().replace("$@", key);
                fm.referTables.addAll(Table.parse(fm.placeholder, em));
            } else {
                fm.placeholder = key;
            }
            merge(filtered, fm);
        }
    }

    /**
     * Turns each non-ignored {@code @Filter} annotation into a {@link FilterMeta} and merges it
     * into {@code filtered}.
     *
     * <p>The target column comes from the annotation's {@code column} (optionally with its own
     * table and alias) or, failing that, from the entity field named by {@code field}. A filter
     * whose conflict action is {@code discard} is dropped when the column already has an entry.
     *
     * @param array    annotations to process
     * @param em       entity metadata used to resolve fields and the default table
     * @param filtered per-column filter lists that the new filters are merged into
     * @throws RuntimeException if the column cannot be resolved or the placeholder is empty
     */
    public static void mergeFromFilters(Filter[] array, EntityMeta em, Map<TableColumn, List<FilterMeta>> filtered) {
        for (Filter annotation : array) {
            if (annotation.ignore()) {
                continue;
            }

            TableColumn target;
            if (!annotation.column().isEmpty()) {
                Table owner;
                if (annotation.table().isEmpty()) {
                    owner = em.getTable();
                } else {
                    owner = new Table(annotation.table(), annotation.tblAlias());
                }
                target = new TableColumn(owner.table, owner.alias, annotation.column());
            } else if (!annotation.field().isEmpty()) {
                target = em.getFieldsMap().get(annotation.field());
                if (target == null) {
                    throw new RuntimeException("无法解析的字段：" + em.getEntityClass().getName() + "." + annotation.field());
                }
            } else {
                throw new RuntimeException("无效的过滤条件@Filter：" + annotation + "。实体类：" + em.getEntityClass());
            }

            if (annotation.placeholder().isEmpty()) {
                throw new RuntimeException("请指定@Filter的placeholder属性：" + annotation + "。实体类：" + em.getEntityClass());
            }

            // A "discard" filter yields to whatever is already registered for the column.
            boolean columnTaken = filtered.get(target) != null;
            if (columnTaken && annotation.conflict() == ConflictAction.discard) {
                continue;
            }

            FilterMeta meta = new FilterMeta();
            meta.setTableColumn(target);
            meta.setPlaceholder(annotation.placeholder());
            meta.setType(annotation.type());
            meta.setOrder(annotation.order());
            meta.setConflict(annotation.conflict());
            meta.referTables.add(target.toTable());
            meta.referTables.addAll(Table.parse(annotation.placeholder(), em));

            merge(filtered, meta);
        }
    }

    /**
     * Builds the per-column filter map for one invocation of a mapper method.
     *
     * <p>Each declared parameter is matched against the first of its annotations that is
     * recognised: {@code @Limit}, {@code @Offset} and {@code @Sort} update {@code pas},
     * {@code @IdParam} and {@code @FilterParam} contribute filters. Parameters whose argument is
     * null are skipped. When nothing was recognised and the method has exactly one non-null
     * argument and no {@code @ExecuteUpdate}, that argument is treated as a query bean.
     *
     * <p>NOTE(review): the index rewind below implies {@code args} contains no slot for
     * {@link RowBounds} parameters - confirm against the caller.
     *
     * @param em     entity metadata used to resolve fields and columns
     * @param method mapper method being invoked
     * @param args   arguments of the invocation
     * @param pas    paging and sorting holder to populate; may be null
     * @return filters grouped by table column
     * @throws RuntimeException if a sort argument is not a String or a limit/offset argument has
     *                          an unsupported type
     */
    public static Map<TableColumn, List<FilterMeta>> parseArgs(EntityMeta em, Method method, Object[] args, PageAndSort pas) {
        Map<TableColumn, List<FilterMeta>> map = new HashMap<>();
        Class<?>[] paramTypes = method.getParameterTypes();
        Annotation[][] ann = method.getParameterAnnotations();

        boolean processed = false;
        for (int i = 0, argIndex = 0; i < ann.length; i++, argIndex++) {
            // RowBounds must be recognised before args[argIndex] is touched: it has no slot in
            // args, so reading first would look at the neighbouring argument (or past the end of
            // the array when RowBounds is the last parameter) and could skip the rewind.
            if (paramTypes[i] == RowBounds.class) {
                argIndex -= 1; // step the argument index back by one
                continue;
            }
            if (args[argIndex] == null) {
                continue;
            }

            for (int j = 0; j < ann[i].length; j++) {
                Class<?> annType = ann[i][j].annotationType();
                if (annType == Limit.class) {
                    if (pas != null && !pas.isSelectOne() && !pas.isSelectCount()) {
                        pas.setLimit(parseLong(args[argIndex], method));
                    }
                    processed = true;
                } else if (annType == Offset.class) {
                    if (pas != null && !pas.isSelectOne() && !pas.isSelectCount()) {
                        pas.setOffset(parseLong(args[argIndex], method));
                    }
                    processed = true;
                } else if (annType == Sort.class) {
                    if (pas != null && !pas.isSelectCount()) {
                        if (args[argIndex].getClass() != String.class) {
                            throw new RuntimeException("排序字段必须为String类型，"
                                + args[argIndex].getClass() + "：" + Reflection.formatMethod(method));
                        }
                        pas.setSort(PageAndSort.SortMeta.parseSort(em, (String) args[argIndex]));
                    }
                    processed = true;
                } else if (annType == IdParam.class) {
                    margeIdParam(map, em, method, i, args[argIndex], (IdParam) ann[i][j]);
                    processed = true;
                } else if (annType == FilterParam.class) {
                    FilterParam f = (FilterParam) ann[i][j];
                    if (f.ignoreOnZero() && Reflection.isNumberZero(args[argIndex])) {
                        // Zero is treated as "not supplied"; keep looking at the other annotations.
                        continue;
                    }
                    mergeParam(map, em, method, i, args[argIndex], f, pas);
                    processed = true;
                } else {
                    // Unrelated annotation: try the next one on this parameter.
                    continue;
                }

                // Only the first recognised annotation of a parameter is applied.
                break;
            }
        }

        if (!processed && args.length == 1 && args[0] != null && method.getAnnotation(ExecuteUpdate.class) == null) { // treated as a query bean by default
            mergeParam(map, em, method, 0, args[0], null, pas);
        }
        return map;
    }

    /**
     * Adds {@code meta} to the filter list of its column, applying its conflict action when the
     * column already has filters.
     *
     * <ul>
     *   <li>empty list: the filter is always added;</li>
     *   <li>{@code coexist}: existing {@code discard} filters are removed, then the filter is added;</li>
     *   <li>{@code override}: existing {@code discard} and {@code override} filters are removed,
     *       then the filter is added;</li>
     *   <li>any other action: the filter is dropped.</li>
     * </ul>
     *
     * @param filtered per-column filter lists; a list is created for the column on demand
     * @param meta     filter to merge
     */
    public static void merge(Map<TableColumn, List<FilterMeta>> filtered, FilterMeta meta) {
        List<FilterMeta> bucket = filtered.computeIfAbsent(meta.getTableColumn(), column -> new ArrayList<>());

        if (!bucket.isEmpty()) {
            ConflictAction incoming = meta.getConflict();
            if (incoming == ConflictAction.coexist) {
                bucket.removeIf(existing -> existing.getConflict() == ConflictAction.discard);
            } else if (incoming == ConflictAction.override) {
                bucket.removeIf(existing -> {
                    ConflictAction current = existing.getConflict();
                    return current == ConflictAction.discard || current == ConflictAction.override;
                });
            } else {
                // The column is already filtered and the newcomer does not insist on being kept.
                return;
            }
        }
        bucket.add(meta);
    }

    /**
     * Merges the filters contributed by a single method argument.
     *
     * <p>Resolution order:
     * <ol>
     *   <li>{@code @FilterParam} with a {@code field}: one filter on the column mapped to that field;</li>
     *   <li>{@code @FilterParam} with a {@code column}: one filter on that column;</li>
     *   <li>a {@link Map} argument: one {@code eq} filter per key, each key naming an entity field;</li>
     *   <li>anything else: the argument is treated as a filter bean and each non-null property is
     *       inspected for {@code @Limit}, {@code @Offset}, {@code @Sort} and {@code @Filter}.</li>
     * </ol>
     *
     * @param filtered per-column filter lists that new filters are merged into
     * @param em       entity metadata used to resolve fields and columns
     * @param method   mapper method being invoked (used for placeholders and error messages)
     * @param argIndex index of the parameter the argument belongs to
     * @param arg      non-null argument value
     * @param param    annotation on the parameter, or null when the argument is an unannotated query bean
     * @param pas      paging and sorting holder updated by bean properties; may be null
     * @throws RuntimeException if an annotated field or a map key cannot be resolved to a column,
     *                          or a bean property carries an invalid sort/limit/offset value
     */
    @SuppressWarnings("unchecked")
    private static void mergeParam(Map<TableColumn, List<FilterMeta>> filtered, EntityMeta em, Method method, int argIndex, Object arg, FilterParam param, PageAndSort pas) {
        // The merged attributes are only read when an annotation is present, so do not ask Spring
        // to introspect a null annotation.
        AnnotationAttributes attr = param == null ? null
            : AnnotatedElementUtils.getMergedAnnotationAttributes(AnnotatedElementUtils.forAnnotations(param), FilterParam.class);
        FilterMeta fm = new FilterMeta();
        fm.argIndex = argIndex;

        if (param != null && !StringUtils.isEmpty(attr.getString("field"))) {
            FieldMeta fdm = em.getFieldMeta().get(em.getTableColumn(attr.getString("field")));
            if (fdm == null) {
                throw new RuntimeException("无法解析的字段：" + attr.getString("field") + "。方法：" + Reflection.formatMethod(method));
            }
            fm.setTableColumn(fdm.getTableColumn());
        } else if (param != null && !StringUtils.isEmpty(param.column())) {
            Table table = StringUtils.isEmpty(param.table()) ? em.getTable() : new Table(param.table(), param.tblAlias());
            fm.setTableColumn(new TableColumn(table, param.column()));
        } else if (arg instanceof Map) { // process Map parameter
            // Cast to the interface: any Map implementation is accepted, not just HashMap.
            Map<String, Object> map = (Map<String, Object>) arg;
            for (String key : map.keySet()) {
                TableColumn tc = em.getFieldsMap().get(key);
                if (tc == null) {
                    // Fail with the offending key instead of an anonymous NullPointerException.
                    throw new RuntimeException("无法解析的字段：" + key + "。方法：" + Reflection.formatMethod(method));
                }
                FilterMeta m = new FilterMeta();
                m.field = key;
                m.setTableColumn(tc);
                m.setArgIndex(argIndex);
                m.setType(FilterType.eq);
                merge(filtered, m);
            }
            return;
        } else { // filter bean
            Reflection.resolveProperties(arg.getClass()).forEach(member -> {
                Object val = Reflection.getPropertyValue(arg, member);
                if (val != null) {
                    Annotation[] anns = Reflection.getPropertyAnnotations(member);
                    // Paging and sorting properties feed pas and never become filters themselves.
                    for (int i = 0; i < anns.length; i++) {
                        Class<?> annType = anns[i].annotationType();
                        if (pas != null && !pas.isSelectCount() && !pas.isSelectOne()) {
                            if (annType == Limit.class) {
                                pas.setLimit(parseLong(val, method));
                                return;
                            } else if (annType == Offset.class) {
                                pas.setOffset(parseLong(val, method));
                                return;
                            }
                        }

                        if (pas != null && !pas.isSelectCount()) {
                            if (annType == Sort.class) {
                                if (val.getClass() != String.class) {
                                    throw new RuntimeException("排序字段必须为String类型，"
                                        + val.getClass() + "：" + Reflection.formatMethod(method));
                                }
                                pas.setSort(PageAndSort.SortMeta.parseSort(em, (String) val));
                                return;
                            }
                        }
                    }

                    Filter filter = Reflection.getPropertyAnnotation(member, Filter.class);
                    if (filter != null && filter.ignore()) {
                        return;
                    } else if (filter != null && filter.ignoreOnZero()
                        && Reflection.isNumberZero(val)) {
                        return;
                    }
                    mergeProperty(filtered, em, method, argIndex, val, member.getName(), anns);
                }
            });
            return;
        }

        // Only the two annotated branches reach this point, so param is non-null here.
        fm.type = param.type();
        if (!StringUtils.isEmpty(param.placeholder())) {
            // NOTE(review): the multi-parameter replacement ends with a dot ("argN.") - confirm
            // that whatever consumes the placeholder expects that form.
            fm.placeholder = param.placeholder().replace("$@", method.getParameterCount() > 1 ? "arg" + argIndex + "." : "param1");
            fm.referTables.addAll(Table.parse(fm.placeholder, em));
        }
        fm.order = param.order();
        fm.setConflict(param.conflict());
        fm.value = arg;
        merge(filtered, fm);
    }

    /**
     * Merges the filters produced by an {@code @IdParam} argument.
     *
     * <p>When the annotation names a field, a single filter on that field's column is created.
     * Otherwise one filter is created per primary column of the entity; with a composite key
     * each filter also records the name of the entity field behind its column.
     *
     * @param filtered per-column filter lists that new filters are merged into
     * @param em       entity metadata used to resolve the id columns
     * @param method   mapper method being invoked
     * @param argIndex index of the annotated parameter
     * @param arg      argument value; a {@link Collection} yields an {@code in} filter
     * @param param    the {@code @IdParam} annotation
     * @throws RuntimeException if the field named by the annotation cannot be resolved
     */
    private static void margeIdParam(Map<TableColumn, List<FilterMeta>> filtered, EntityMeta em, Method method, int argIndex, Object arg, IdParam param) {
        if (!StringUtils.isEmpty(param.value())) {
            TableColumn named = em.getTableColumn(param.value());
            if (named == null) {
                throw new RuntimeException("没有找到IdParam指定的字段，" + param.value() + "："
                    + Reflection.formatMethod(method));
            }
            merge(filtered, newIdFilter(named, null, method, argIndex, arg));
            return;
        }

        List<TableColumn> primaryColumns = em.getPrimaryColumns();
        for (TableColumn primary : primaryColumns) {
            String fieldName = null;
            if (primaryColumns.size() > 1) {
                // Composite key: remember which entity field this column belongs to.
                fieldName = em.getFieldMeta().get(primary).getField().getName();
            }
            merge(filtered, newIdFilter(primary, fieldName, method, argIndex, arg));
        }
    }

    /** Builds the id filter for one column: {@code in} for a collection argument, {@code eq} otherwise. */
    private static FilterMeta newIdFilter(TableColumn column, String fieldName, Method method, int argIndex, Object arg) {
        FilterMeta idFilter = new FilterMeta();
        idFilter.tableColumn = column;
        idFilter.argIndex = argIndex;
        idFilter.field = fieldName;
        idFilter.value = arg;
        if (arg instanceof Collection) {
            idFilter.type = FilterType.in;
            // A collection passed as the only parameter is addressed as "list".
            idFilter.placeholder = method.getParameterCount() > 1 ? null : "list";
        } else {
            idFilter.type = FilterType.eq;
        }
        return idFilter;
    }

    /**
     * Builds a filter from one property of a filter bean and merges it into {@code filtered}.
     *
     * <p>The target column is resolved from the property's {@code @Filter} annotation (explicit
     * column/table, or an entity field name) and otherwise from the entity field that has the
     * same name as the property. A property without {@code @Filter} that maps to no entity
     * field is silently skipped.
     *
     * <p>NOTE(review): {@code filter.conflict()} is not copied onto the filter here, so bean
     * property filters keep the default conflict action - confirm this is intended.
     *
     * @param filtered per-column filter lists that the new filter is merged into
     * @param em       entity metadata used to resolve fields and the default table
     * @param method   mapper method being invoked (used to build the placeholder)
     * @param argIndex index of the parameter holding the bean
     * @param arg      value of the property
     * @param propName name of the property
     * @param anns     annotations declared on the property
     * @throws RuntimeException      if {@code @Filter} names an entity field that cannot be resolved
     * @throws IllegalStateException if the property carries {@code @Filter} but no column can be found
     */
    private static void mergeProperty(Map<TableColumn, List<FilterMeta>> filtered, EntityMeta em, Method method, int argIndex, Object arg, String propName, Annotation[] anns) {
        FilterMeta fm = new FilterMeta();
        fm.field = propName;
        fm.argIndex = argIndex;
        fm.value = arg;
        TableColumn tableColumn = new TableColumn();

        Filter filter = Reflection.findAnnotation(anns, Filter.class);
        if (filter != null) {
            if (!filter.column().isEmpty() || !filter.table().isEmpty()) {
                // Explicit column and/or table on the annotation; a missing table falls back to the entity's table.
                tableColumn.table = filter.table();
                tableColumn.tblAlias = filter.tblAlias();
                tableColumn.column = filter.column();
                if (tableColumn.table.isEmpty()) {
                    tableColumn.table = em.getTable().table;
                }
            } else if (!filter.field().isEmpty()) {
                // The annotation points at another entity field; use that field's column.
                tableColumn = em.getTableColumn(filter.field());
                if (tableColumn == null) {
                    throw new RuntimeException("无法解析的字段：" + em.getEntityClass().getName() + "." + filter.field());
                }
            }
            fm.type = filter.type();
            fm.order = filter.order();
            fm.placeholder = filter.placeholder();
        } else {
            fm.type = FilterType.eq;
        }
        if (!StringUtils.isEmpty(fm.placeholder)) {
            // "$@" stands for the property, prefixed with "argN." when the method has several parameters.
            String argName = method.getParameterCount() > 1 ? "arg" + argIndex + "." : "";
            fm.placeholder = fm.placeholder.replace("$@", argName + propName);
            fm.referTables.addAll(Table.parse(fm.placeholder, em));
        }

        if (StringUtils.isEmpty(tableColumn.table)) {
            // No table so far: default to the entity's own table and alias.
            tableColumn.table = em.getTable().table;
            tableColumn.tblAlias = em.getTable().alias;
        } else if (StringUtils.isEmpty(tableColumn.column)) {
            // A table is known but no column: derive the column name from the property name.
            tableColumn.column = StringUtils.snakeCase(propName);
        }

        if (StringUtils.isEmpty(tableColumn.column)) {
            // Still no column: fall back to the entity field named like the property.
            tableColumn = em.getFieldsMap().get(propName);
            if (tableColumn == null) {
                if (filter != null) {
                    throw new IllegalStateException("没有找到对应的数据表字段"
                        + em.getEntityClass().getName() + "." + propName);
                } else {
                    // Unannotated property that is not an entity field: not a filter.
                    return;
                }
            }
        }
        fm.setTableColumn(tableColumn);
        merge(filtered, fm);
    }

    /**
     * Converts a limit/offset argument to {@link Long}.
     *
     * @param arg    non-null argument value; must be a {@link Long} or an {@link Integer}
     * @param method mapper method being invoked, used in the error message
     * @return the argument as a {@code Long}
     * @throws RuntimeException if the argument is of any other type
     */
    private static Long parseLong(Object arg, Method method) {
        Class<?> argType = arg.getClass();
        if (argType == Long.class) {
            return (Long) arg;
        }
        if (argType == Integer.class) {
            return ((Integer) arg).longValue();
        }
        // Separate the list of accepted types from the actual type; the two used to run together
        // ("...Longjava.lang.String").
        throw new RuntimeException("不支持的参数类型，可选类型：int, long, Integer, Long。实际类型："
            + argType.getName() + "：" + Reflection.formatMethod(method));
    }

    /** @return the column this filter applies to */
    public TableColumn getTableColumn() {
        return this.tableColumn;
    }

    /**
     * Sets the target column and records its table among the referenced tables.
     *
     * @param tableColumn column this filter applies to; must not be null
     */
    public void setTableColumn(TableColumn tableColumn) {
        // Register the table first so a failing toTable() leaves the current column untouched.
        Table owner = tableColumn.toTable();
        this.referTables.add(owner);
        this.tableColumn = tableColumn;
    }

    /** @return index of the method parameter this filter was derived from */
    public int getArgIndex() {
        return this.argIndex;
    }

    /** @param argIndex index of the method parameter this filter was derived from */
    public void setArgIndex(int argIndex) {
        this.argIndex = argIndex;
    }

    /** @return name of the bean property or map key behind this filter, or null */
    public String getField() {
        return this.field;
    }

    /** @param field name of the bean property or map key behind this filter */
    public void setField(String field) {
        this.field = field;
    }

    /** @return comparison type of this filter */
    public FilterType getType() {
        return this.type;
    }

    /** @param type comparison type of this filter */
    public void setType(FilterType type) {
        this.type = type;
    }

    /** @return placeholder expression, or null when none was configured */
    public String getPlaceholder() {
        return this.placeholder;
    }

    /** @param placeholder placeholder expression for the filter value */
    public void setPlaceholder(String placeholder) {
        this.placeholder = placeholder;
    }

    /** @return value of the filter, or null when not recorded */
    public Object getValue() {
        return this.value;
    }

    /** @param value value of the filter */
    public void setValue(Object value) {
        this.value = value;
    }

    /** @return ordering weight of this filter */
    public int getOrder() {
        return this.order;
    }

    /** @param order ordering weight of this filter */
    public void setOrder(int order) {
        this.order = order;
    }

    /** @return action applied when several filters target the same column */
    public ConflictAction getConflict() {
        return this.conflict;
    }

    /** @param conflict action applied when several filters target the same column */
    public void setConflict(ConflictAction conflict) {
        this.conflict = conflict;
    }

    /** @return the live set of tables referenced by this filter (not a copy) */
    public Set<Table> getReferTables() {
        return this.referTables;
    }

    /** @param referTables replacement for the set of referenced tables */
    public void setReferTables(Set<Table> referTables) {
        this.referTables = referTables;
    }
}
